import numpy as np
import mpctools as mpc
import scipy.linalg.interpolative as sli

def RANK_pzc(M, eps=1e-6):
    """Estimate the numerical rank of ``M``.

    Thin wrapper around ``scipy.linalg.interpolative.estimate_rank``.

    Args:
        M: Matrix whose rank is estimated (passed straight to estimate_rank).
        eps: Relative precision handed to estimate_rank. A smaller eps
            resolves smaller singular values, so the estimated rank tends to
            be larger. Defaults to 1e-6, the value previously hard-coded.

    Returns:
        The estimated rank as an integer.
    """
    # NOTE(review): an earlier comment claimed a smaller eps also computes
    # faster; that is unverified.
    return sli.estimate_rank(M, eps=eps)

# Multiplier for the early-termination threshold in SensAnal:
#     degalpha = alpha * sqrt(sigma_w**2 + sigma_v**2)
# Variable selection stops when the largest remaining residual column norm
# falls below degalpha. Column norms are non-negative, so alpha = 0 disables
# this early stop.
alpha = 0


def SenMatrix(F_rk4,H,x,y,u,size):
    """Build the stacked output-sensitivity matrix along a trajectory.

    The state sensitivity S_x[k] = dx_k/dx_0 is propagated with the chain
    rule, starting from S_x[0] = I:

        S_x[k+1] = (dF/dx at (x_k, u_k)) @ S_x[k]

    The output sensitivity S_y[k] = (dH/dx at x_k) @ S_x[k] is formed for
    every k = 0..size.

    Args:
        F_rk4: State-transition model. It is linearised with
            ``mpc.util.getLinearizedModel`` at ``[x[k], u[k]]``, and its "A"
            Jacobian is used as dF/dx.
        H: Measurement model. It is linearised at ``[x[k]]``, and its "A"
            Jacobian is used as dH/dx.
        x: States indexed as ``x[k, :]``. Rows 0..size are read;
            ``x.shape[1]`` gives Nx.
        y: Only ``y.shape[1]`` (Ny) is read.
        u: Inputs indexed as ``u[k, :]``. Rows 0..size-1 are read; none are
            read when ``size == 0``.
        size: Number of model transitions (the window size).

    Returns:
        ndarray of shape ``((size+1)*Ny, Nx)``: the blocks S_y[0], ...,
        S_y[size] stacked vertically.
    """
    Nx = x.shape[1]
    Ny = y.shape[1]
    Sxx = np.zeros([size+1,Nx,Nx])
    Sxx[0,:,:] = np.eye(Nx)
    Syx = np.zeros([size+1,Ny,Nx])

    # One pass over every time point. The measurement Jacobian is needed at
    # all size+1 points, but the transition Jacobian only for the first `size`
    # points. This replaces the separate terminal-step branch, which depended
    # on the leaked loop variable.
    for k in range(size+1):
        JH = mpc.util.getLinearizedModel(H, [x[k,:]], ["A"], None)
        Syx[k,:,:] = mpc.mtimes(JH["A"], Sxx[k,:,:])
        if k < size:
            JF = mpc.util.getLinearizedModel(F_rk4, [x[k,:],u[k,:]],
                                             ["A","B"], None)
            Sxx[k+1,:,:] = mpc.mtimes(JF["A"], Sxx[k,:,:])

    # Stacking the (Ny, Nx) blocks vertically is exactly a row-major reshape
    # of the (size+1, Ny, Nx) array.
    return Syx.reshape((size+1)*Ny, Nx)



def SensAnal(Sall,Nx,Ny,sigma_w,sigma_v,size,rank_xp): # size is the value of window size in FIE and MHE
    """Rank the columns of a sensitivity matrix by orthogonal projection.

    At each step, the column of the current residual with the largest
    Euclidean norm is selected. The original column of ``Sall`` is added to
    the selected set. The residual is then recomputed as ``Sall`` minus its
    projection onto the span of the selected columns, with already-selected
    columns forced to zero.

    Selection stops when any of these happens:
      * the largest residual column norm drops below
        ``alpha * sqrt(sigma_w**2 + sigma_v**2)``;
      * ``rank`` columns have been selected;
      * the Gram matrix of the selected columns is rank-deficient according
        to RANK_pzc.

    Args:
        Sall: Sensitivity matrix, expected to be ``((size+1)*Ny, Nx)``.
        Nx: Number of columns (states).
        Ny: Number of outputs per time step.
        sigma_w: Used only in the termination threshold.
        sigma_v: Used only in the termination threshold.
        size: Window size; ``Sall`` has ``(size+1)*Ny`` rows.
        rank_xp: If falsy, no selection is performed and Js is 0.

    Returns:
        1-D array ``[rank_S, Js]``. ``rank_S`` is RANK_pzc(Sall) estimated
        before selection. Js is the post-selection rank times the sum, over
        steps, of the largest residual column norm.
    """
    rank = RANK_pzc(Sall)

    rank_S = np.zeros([1,1],dtype=int)
    rank_S[0,0] = rank

    # Variable selection procedure
    n_rows = (size+1)*Ny
    # SumCol[i, j]: norm of residual column j at step i. Rows never reached
    # stay zero.
    SumCol = np.zeros([Nx,Nx])
    # Flag[i, 0]: index of the column selected at step i.
    Flag = np.zeros([Nx,1],dtype=int)
    # Selected (original) columns of Sall, filled left to right.
    Xl = np.zeros([n_rows,Nx])
    # Only the current residual is ever read, so one buffer replaces the
    # full per-step history of residuals and projections.
    residual = np.zeros([n_rows,Nx])
    residual[:,:] = Sall

    degalpha = alpha*np.sqrt(sigma_w**2+sigma_v**2)
    if rank_xp:
        for i in range(rank):
            # Column-wise Euclidean norms of the residual.
            SumCol[i,:] = np.sqrt(np.sum(residual**2, axis=0))
            # np.argmax returns the first maximum, which matches the original
            # strict "<" scan's tie-breaking.
            Flag[i,0] = np.argmax(SumCol[i,:])

            # A prescribed value for sensitivity to terminate
            if SumCol[i,Flag[i,0]] < degalpha:
                rank = i
                break
            Xl[:,i] = Sall[:,Flag[i,0]]

            # All variables selected: terminate the selection
            if i == rank-1:
                break

            selected = Xl[:,0:i+1]
            gram = mpc.mtimes(selected.T, selected)
            if RANK_pzc(gram) != i+1:
                # Selected columns are numerically dependent (singular Gram
                # matrix): terminate.
                rank = i+1
                break
            # Orthogonal projection of Sall onto span(selected).
            projection = mpc.mtimes(mpc.mtimes(mpc.mtimes(selected,
                                    np.linalg.inv(gram)), selected.T), Sall)
            residual[:,:] = Sall - np.asarray(projection, dtype=float)
            # Eliminate residual non-zero values in already-selected columns
            # so that round-off cannot select them again.
            residual[:, Flag[0:i+1,0]] = 0

    maxcol = np.amax(SumCol,axis=1)
    Js = rank*np.sum(maxcol)
    # NOTE(review): debug output kept to preserve existing behaviour; consider
    # removing it or switching to logging.
    print(rank)
    # NOTE(review): Result reports rank_S (the rank estimated before
    # selection), while Js uses the possibly reduced `rank`. Confirm this is
    # intended.
    Result = np.concatenate((rank_S,Js),axis = None)
    return Result